{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Character-level text generation with LSTM\n",
    "\n",
    "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
    "**Date created:** 2015/06/15<br>\n",
    "**Last modified:** 2020/04/30<br>\n",
    "**Description:** Generate text from Nietzsche's writings with a character-level LSTM."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "This example demonstrates how to use an LSTM model to generate\n",
    "text character-by-character.\n",
    "\n",
    "At least 20 epochs are required before the generated text\n",
    "starts sounding locally coherent.\n",
    "\n",
    "It is recommended to run this script on GPU, as recurrent\n",
    "networks are quite computationally intensive.\n",
    "\n",
    "If you try this script on new data, make sure your corpus\n",
    "has at least 100k characters; around 1M is better.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "import numpy as np\n",
    "import random\n",
    "import io\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Prepare the data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "path = keras.utils.get_file(\n",
    "    \"nietzsche.txt\", origin=\"https://s3.amazonaws.com/text-datasets/nietzsche.txt\"\n",
    ")\n",
    "with io.open(path, encoding=\"utf-8\") as f:\n",
    "    text = f.read().lower()\n",
    "text = text.replace(\"\\n\", \" \")  # Replace newlines with spaces for nicer display\n",
    "print(\"Corpus length:\", len(text))\n",
    "\n",
    "chars = sorted(set(text))\n",
    "print(\"Total chars:\", len(chars))\n",
    "char_indices = {c: i for i, c in enumerate(chars)}\n",
    "indices_char = {i: c for i, c in enumerate(chars)}\n",
    "\n",
    "# Cut the text into overlapping sequences of maxlen characters, one every step characters\n",
    "maxlen = 40\n",
    "step = 3\n",
    "sentences = []\n",
    "next_chars = []\n",
    "for i in range(0, len(text) - maxlen, step):\n",
    "    sentences.append(text[i : i + maxlen])\n",
    "    next_chars.append(text[i + maxlen])\n",
    "print(\"Number of sequences:\", len(sentences))\n",
    "\n",
    "# One-hot encode the inputs and targets (np.bool was removed from NumPy; use \"bool\")\n",
    "x = np.zeros((len(sentences), maxlen, len(chars)), dtype=\"bool\")\n",
    "y = np.zeros((len(sentences), len(chars)), dtype=\"bool\")\n",
    "for i, sentence in enumerate(sentences):\n",
    "    for t, char in enumerate(sentence):\n",
    "        x[i, t, char_indices[char]] = 1\n",
    "    y[i, char_indices[next_chars[i]]] = 1\n",
    "\n"
   ]
  },
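  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the windowing and one-hot encoding concrete, the cell below is a minimal\n",
    "sketch that repeats the same steps on a toy string. The `toy_*` names and the values\n",
    "chosen for them are illustrative assumptions and are not used anywhere else in this example.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Illustrative sketch only: the toy_* names are hypothetical\n",
    "toy_text = \"hello world\"\n",
    "toy_maxlen, toy_step = 4, 3\n",
    "toy_chars = sorted(set(toy_text))\n",
    "toy_char_indices = {c: i for i, c in enumerate(toy_chars)}\n",
    "\n",
    "# Each window of toy_maxlen characters is paired with the character that follows it\n",
    "toy_windows = []\n",
    "toy_targets = []\n",
    "for i in range(0, len(toy_text) - toy_maxlen, toy_step):\n",
    "    toy_windows.append(toy_text[i : i + toy_maxlen])\n",
    "    toy_targets.append(toy_text[i + toy_maxlen])\n",
    "print(list(zip(toy_windows, toy_targets)))\n",
    "\n",
    "# One-hot encode the windows: shape is (number of windows, toy_maxlen, vocabulary size)\n",
    "toy_x = np.zeros((len(toy_windows), toy_maxlen, len(toy_chars)), dtype=\"bool\")\n",
    "for i, window in enumerate(toy_windows):\n",
    "    for t, char in enumerate(window):\n",
    "        toy_x[i, t, toy_char_indices[char]] = True\n",
    "print(toy_x.shape)\n"
   ]
  },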
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Build the model: a single LSTM layer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model = keras.Sequential(\n",
    "    [\n",
    "        keras.Input(shape=(maxlen, len(chars))),\n",
    "        layers.LSTM(128),\n",
    "        layers.Dense(len(chars), activation=\"softmax\"),\n",
    "    ]\n",
    ")\n",
    "optimizer = keras.optimizers.RMSprop(learning_rate=0.01)\n",
    "model.compile(loss=\"categorical_crossentropy\", optimizer=optimizer)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Prepare the text sampling function\n"
   ]
  },
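  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "def sample(preds, temperature=1.0):\n",
    "    # Helper function to sample an index from a probability array.\n",
    "    # Lower temperatures favor the most likely characters; higher ones add variety.\n",
    "    preds = np.asarray(preds).astype(\"float64\")\n",
    "    # Rescale the log-probabilities by the temperature, then renormalize\n",
    "    preds = np.log(preds) / temperature\n",
    "    exp_preds = np.exp(preds)\n",
    "    preds = exp_preds / np.sum(exp_preds)\n",
    "    # Draw one one-hot sample and return the index of the chosen character\n",
    "    probas = np.random.multinomial(1, preds, 1)\n",
    "    return np.argmax(probas)\n",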
    "\n"
   ]
  },
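  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The temperature divides the log-probabilities before they are renormalized, so values\n",
    "below 1 sharpen the distribution toward the most likely character and values above 1\n",
    "flatten it. The cell below is a minimal sketch of that effect on a made-up three-way\n",
    "distribution; `reweight` and `toy_preds` are illustrative names, not part of the model.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def reweight(preds, temperature):\n",
    "    # Hypothetical helper: the same rescaling as a temperature-based sampler,\n",
    "    # but it returns the reweighted distribution instead of drawing from it\n",
    "    logits = np.log(np.asarray(preds, dtype=\"float64\")) / temperature\n",
    "    exp_logits = np.exp(logits)\n",
    "    return exp_logits / np.sum(exp_logits)\n",
    "\n",
    "\n",
    "toy_preds = [0.6, 0.3, 0.1]  # made-up probabilities for illustration\n",
    "for temperature in [0.2, 1.0, 1.2]:\n",
    "    print(temperature, np.round(reweight(toy_preds, temperature), 3))\n"
   ]
  },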
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train the model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "epochs = 40\n",
    "batch_size = 128\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    model.fit(x, y, batch_size=batch_size, epochs=1)\n",
    "    print()\n",
    "    print(\"Generating text after epoch: %d\" % epoch)\n",
    "\n",
    "    # Pick a random seed position; the same seed is reused for every diversity value\n",
    "    start_index = random.randint(0, len(text) - maxlen - 1)\n",
    "    for diversity in [0.2, 0.5, 1.0, 1.2]:\n",
    "        print(\"...Diversity:\", diversity)\n",
    "\n",
    "        generated = \"\"\n",
    "        sentence = text[start_index : start_index + maxlen]\n",
    "        print('...Generating with seed: \"' + sentence + '\"')\n",
    "\n",
    "        for i in range(400):\n",
    "            # One-hot encode the current window of maxlen characters\n",
    "            x_pred = np.zeros((1, maxlen, len(chars)))\n",
    "            for t, char in enumerate(sentence):\n",
    "                x_pred[0, t, char_indices[char]] = 1.0\n",
    "            preds = model.predict(x_pred, verbose=0)[0]\n",
    "            next_index = sample(preds, diversity)\n",
    "            next_char = indices_char[next_index]\n",
    "            # Slide the window forward by one character\n",
    "            sentence = sentence[1:] + next_char\n",
    "            generated += next_char\n",
    "\n",
    "        print(\"...Generated: \", generated)\n",
    "        print()\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "lstm_character_level_text_generation",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}